import time
import os
import itertools

import networkx as nx
import networkx.algorithms.community as nx_comm

from netgraph import Graph
from netgraph import ArcDiagram

from sknetwork.clustering import Louvain

import pandas as pd
import numpy as np
from scipy.spatial.distance import pdist, squareform
from sklearn.decomposition import PCA
from sklearn.neighbors import NearestNeighbors

import gseapy as gp
from gseapy import barplot, dotplot, Biomart

import somadata

# Apply seaborn's "whitegrid" theme as the global plotting default.
# NOTE(review): `sns` (seaborn) is not imported in the visible import block;
# this line raises NameError unless `import seaborn as sns` is added above.
sns.set_theme(style="whitegrid")

def coloring_communities(cluster, color_list):
    """Build per-node color, label and community-index maps from a partition.

    Parameters
    ----------
    cluster : sized iterable of iterables
        Communities; each inner iterable holds the node ids of one community.
        Communities are numbered 0, 1, ... in iteration order.
    color_list : sequence, array-like or None
        One color per community. When ``None``, no colors are assigned and the
        returned color map is empty. When it has fewer entries than there are
        communities, the extra communities are skipped entirely (they appear
        in none of the returned maps) because communities and colors are
        paired with ``zip``.

    Returns
    -------
    node_color : dict
        node -> color of its community (empty when ``color_list`` is None).
    node_labels : dict
        node -> node (identity labels, e.g. for drawing).
    node_community : dict
        node -> index of its community.
    """
    node_color = {}
    node_labels = {}
    node_community = {}
    # ``is None`` instead of ``== None``: for an array-like color_list (e.g. a
    # NumPy array of RGB rows) ``== None`` is elementwise and ``if`` would raise.
    if color_list is None:
        for idx, com in enumerate(cluster):
            for c in com:
                node_community[c] = idx
                node_labels[c] = c
    else:
        n_communities = len(cluster)
        for idx, (com, col) in enumerate(zip(cluster, color_list[:n_communities])):
            for c in com:
                node_color[c] = col
                node_community[c] = idx
                node_labels[c] = c
    return node_color, node_labels, node_community

def get_positive_corr_edges(G):
    """Return a copy of ``G`` with every edge whose weight is >= 0 removed.

    ``G`` itself is left untouched. Every edge must carry a ``'weight'``
    attribute; a missing one raises ``KeyError``.

    NOTE(review): despite the name, the edges that survive are the ones with
    weight < 0. A commented-out test (``weight == -1``) in an earlier revision
    would have kept the positive edges instead. The current behavior is
    preserved here; confirm the intended sign convention with callers.
    """
    pruned = G.copy()
    to_drop = [
        (u, v)
        for u, v, attrs in pruned.edges(data=True)
        if attrs['weight'] >= 0
    ]
    pruned.remove_edges_from(to_drop)
    return pruned

def get_negative_corr_edges(G):
    """Return a copy of ``G`` with every edge whose weight is < 0 removed.

    ``G`` itself is left untouched. Every edge must carry a ``'weight'``
    attribute; a missing one raises ``KeyError``.

    NOTE(review): despite the name, the edges that survive are the ones with
    weight >= 0. A commented-out test (``weight == 1``) in an earlier revision
    would have kept the negative edges instead. The current behavior is
    preserved here; confirm the intended sign convention with callers.
    """
    pruned = G.copy()
    to_drop = [
        (u, v)
        for u, v, attrs in pruned.edges(data=True)
        if attrs['weight'] < 0
    ]
    pruned.remove_edges_from(to_drop)
    return pruned

def graph_diff(G1, G2):
    """Return the edges of ``G1`` that ``G2`` lacks, restricted to shared nodes.

    ``nx.difference`` requires both graphs to have the same node set, so both
    graphs are first reduced to the nodes they have in common. Nodes present
    in only one of the graphs are ignored and do not appear in the result.

    Parameters
    ----------
    G1, G2 : networkx graphs
        Should be of the same kind; ``nx.difference`` rejects mixing a graph
        with a multigraph.

    Returns
    -------
    networkx graph
        New graph whose nodes are the common nodes and whose edges are the
        edges of ``G1`` between them that ``G2`` does not have. Per
        ``nx.difference``, graph/node/edge attributes are not copied.
    """
    # Subgraphs of nodes found in only one graph were previously built and then
    # discarded; they never influenced the result, so they are not computed.
    common_nodes = set(G1.nodes) & set(G2.nodes)
    return nx.difference(G1.subgraph(common_nodes), G2.subgraph(common_nodes))

def has_path(G, source, target):
    """Report whether ``target`` is reachable from ``source`` in ``G``.

    Any NetworkX error raised while searching for a shortest path is treated
    as "no path". That includes ``source`` or ``target`` not being a node of
    ``G``, so, unlike ``nx.has_path``, this never raises for unknown nodes.
    Callers can therefore probe arbitrary node pairs.
    """
    try:
        nx.shortest_path(G, source, target)
    except nx.NetworkXException:
        reachable = False
    else:
        reachable = True
    return reachable


# Ten RGB tuples (alpha channel dropped) sampled evenly across 'tab10'.
# NOTE(review): `plt` (matplotlib.pyplot) is not imported in the visible
# import block.
cmap = plt.get_cmap('tab10')
color_list = tuple(
    tuple(rgba[:-1]) for rgba in cmap(np.linspace(0, 1, 10)).tolist()
)

# Clinical data: for each sheet keep the seven sampling days, one row per day,
# then place the sheets side by side and transpose so each row is one
# clinical variable and each column one time point.
sheet_names = ['Physical exam', 'Chemistry', 'Complete Blood Count ']
sampling_days = [-3, 7, 13, 20, 26, 33, 51]
# Parse the workbook once: a list for sheet_name returns {sheet name: DataFrame}.
clinical_sheets = pd.read_excel('/Xenotransplant Clinical data ( POD -3 to POD +51 ).xlsx',
                                sheet_name=sheet_names)
df_clin = []
for sn in sheet_names:
    sheet = clinical_sheets[sn]
    if sn in ['Chemistry', 'Complete Blood Count ', 'Coagulation', 'Urine Chemistry']:
        # Row labelled 0 is dropped on these sheets (presumably a units /
        # sub-header row -- TODO confirm against the workbook).
        sheet = sheet.drop(0)
    day_col, second_col = sheet.columns[0], sheet.columns[1]
    # Plain column assignment so the day column really becomes integer
    # (an in-place .iloc assignment keeps the column's previous dtype).
    sheet[day_col] = sheet[day_col].astype(int)
    sheet = sheet[sheet[day_col].isin(sampling_days)]
    # Several rows can share a day: sort by (day, second column), keep the first.
    sheet = sheet.sort_values([day_col, second_col]).drop_duplicates(subset=day_col)
    df_clin.append(sheet.set_index(day_col).dropna(axis=1))
df_clin = pd.concat(df_clin, axis=1)
df_clin = df_clin[[c for c in df_clin.columns if 'Date' not in c]]
df_clin = df_clin.drop('Blood Pressure (mmHg )- (Average / Day)', axis=1)
df_clin = df_clin.astype(float)
df_clin = df_clin.T
# Assumes all seven sampling days were found, in ascending order.
df_clin.columns = ['Pre_Tx', 'D7', 'D13', 'D20', 'D26', 'D33', 'D51']

# Targeted metabolomics: metabolites x samples of normalized area. Duplicate
# (metabolite, sample) entries are summed; the 'pig' sample and any metabolite
# with a missing value are dropped.
df_mb = (
    pd.read_excel('MC-TB-115_results.xlsx', sheet_name='targeted_data')
    .pivot_table(index='Metabolite', columns='Sample_ID',
                 values='normalized_area', aggfunc='sum')
    .drop('pig', axis=1)
    .dropna(axis=0)
)
# NOTE(review): columns are renamed by position, so this relies on the column
# order produced by pivot_table matching this list -- verify against the sheet.
df_mb.columns = ['Pre_Tx_1', 'Pre_Tx', 'D7', 'D13', 'D20', 'D26', 'D33', 'D51']
df_mb = df_mb.drop('Pre_Tx_1', axis=1)

# SomaScan proteomics (ADAT): keep the 'Visit' samples, keep analytes whose
# ColCheck is PASS and whose Organism is Human, collapse analytes sharing a
# gene symbol by their median, and drop the empty gene symbol.
adat = somadata.read_adat('HMS-24-042_v5.0_EDTAPlasma.hybNorm.medNormInt.plateScale.calibrate.anmlQC.qcCheck.anmlSMP.20240827.adat')
col_check = adat.pick_meta(axis=1, names=['ColCheck'])
col_genes = adat.pick_meta(axis=1, names=['EntrezGeneSymbol'])  # not read below in this section
col_organism = adat.pick_meta(axis=1, names=['Organism'])

df_ss = adat.pick_meta(axis=1, names=['EntrezGeneSymbol']).reset_index()
df_ss = df_ss[df_ss.SampleDescription.str.contains('Visit')]
# Analytes become rows; 'CRYBB2' is presumably the first analyte column after
# the sample metadata columns -- TODO confirm.
df_ss = df_ss.loc[:, 'CRYBB2':].T
df_ss['col_check'] = col_check.columns
df_ss['col_organism'] = col_organism.columns
ss_keep = (df_ss.col_check == 'PASS') & (df_ss.col_organism == 'Human')
df_ss = df_ss[ss_keep].drop(['col_check', 'col_organism'], axis=1)
df_ss = df_ss.reset_index().groupby('EntrezGeneSymbol').median()
# Analytes without a gene symbol are grouped under '' -- discard that row.
df_ss = df_ss[df_ss.index != '']
df_ss.columns = ['Pre_Tx', 'D7', 'D13', 'D20', 'D26', 'D33', 'D51']

# Stack clinical, metabolomics and proteomics variables (rows) over the
# shared time columns.
# NOTE(review): drop_duplicates() compares row *values*, not labels, so two
# distinct variables with identical trajectories collapse into one.
df_fc = pd.concat([df_clin, df_mb, df_ss]).drop_duplicates()
del df_clin, df_mb, df_ss

# Zero baselines would make the fold change infinite. Replace each zero
# Pre_Tx with 1/1000 of that row's smallest positive value.
zero_baseline_rows = df_fc.loc[df_fc.loc[df_fc.Pre_Tx == 0].index, :]
pre_tx_zero_dict = zero_baseline_rows[zero_baseline_rows > 0].min(axis=1) / 1000
df_fc.loc[pre_tx_zero_dict.index, 'Pre_Tx'] = pre_tx_zero_dict.values

# log2(D51 / Pre_Tx) per variable; df_fc is rebound to a {variable: log2FC} dict.
log2_fc_series = np.log2(df_fc.D51 / df_fc.Pre_Tx)
df_fc = dict(log2_fc_series)
# NOTE(review): df_clin_mb_ss is not defined anywhere in this file section.
df_clin_mb_ss['log2(FC)'] = df_clin_mb_ss.index.map(df_fc)


# Positive FDA cluster: proteins of FDA cluster 0 with log2(D51/Pre_Tx) > 0.
# .copy() because a 'Louvain' column is added below (avoids assigning into a
# filtered slice / SettingWithCopyWarning).
df_clin_mb_ss_pos = df_clin_mb_ss[
    (df_clin_mb_ss.FDA_cluster_clin_mb_ss == 0)
    & (df_clin_mb_ss.type == 'proteomics')
    & (df_clin_mb_ss['log2(FC)'] > 0.)
].copy()

# Louvain communities on the k-nearest-neighbour graph of the protein
# trajectories (sklearn default n_neighbors=5).
nearestNeighbors_object = NearestNeighbors()
louvain_object = Louvain(random_state=42, resolution=.25)

cl_X = df_clin_mb_ss_pos.loc[:, 'Pre_Tx':'D51'].values
nearestNeighbors_object.fit(cl_X)
labels = louvain_object.fit_predict(nearestNeighbors_object.kneighbors_graph())
df_clin_mb_ss_pos['Louvain'] = labels

## Propagate the protein community labels to metabolomics and clinical variables
# Sampling days (POD) of the seven time columns Pre_Tx..D51.
grid_points = [-3, 7, 13, 20, 26, 33, 51]
fd_train = FDataGrid(df_clin_mb_ss_pos.loc[:, 'Pre_Tx':'D51'].values, grid_points)

neigh = KNeighborsClassifier()
neigh.fit(fd_train, labels)

# NOTE(review): unlike the proteins, these variables are not filtered on log2(FC).
aux = df_clin_mb_ss[
    (df_clin_mb_ss.FDA_cluster_clin_mb_ss == 0) & (df_clin_mb_ss.type != 'proteomics')
].copy()
fd_predict = FDataGrid(aux.loc[:, 'Pre_Tx':'D51'].values, grid_points)
labels_mb_clin = neigh.predict(fd_predict)

print(np.unique(labels_mb_clin, return_counts=True))

aux['Louvain'] = labels_mb_clin

# One frame with every node of the positive network: time course + community + layer.
network_columns = ['Pre_Tx', 'D7', 'D13', 'D20', 'D26', 'D33', 'D51', 'Louvain', 'type']
aux = pd.concat([df_clin_mb_ss_pos.loc[:, network_columns],
                 aux.loc[:, network_columns]
                ], axis=0)

# Enrichment analysis of each positive Louvain community with gseapy.enrichr
# against the .gmt gene-set files below; rows with adjusted P <= 0.1 are
# written to one CSV per community and the elapsed time is printed.
enr = {}
gene_sets = ['c2.cp.reactome.v2024.1.Hs.symbols.gmt',
             'c7.immunesigdb.v2024.1.Hs.symbols.gmt',
             'c5.go.bp.v2024.1.Hs.symbols.gmt',
             'c2.cp.kegg_medicus.v2024.1.Hs.symbols.gmt']
# NOTE(review): not passed to gp.enrichr below (output names say
# "default_background"), so gseapy's default background is used.
background = df_clin_mb_ss[df_clin_mb_ss.type == 'proteomics'].index.tolist()

for community in df_clin_mb_ss_pos.Louvain.unique():
    if community in enr:
        continue
    started = time.time()
    members = df_clin_mb_ss_pos[df_clin_mb_ss_pos.Louvain == community].index.tolist()
    enr[community] = gp.enrichr(gene_list=members,
                                gene_sets=gene_sets,
                                organism='homo sapiens',
                                outdir=None,
                                no_plot=True)
    results = enr[community].results
    significant = results[results['Adjusted P-value'] <= .1].sort_values('Adjusted P-value')
    significant.to_csv('enrichment_' + str(community) + '_positive_FDA_Louvain_cluster_c2_c5_c7_default_background.csv')
    elapsed = time.time() - started
    print(str(community) + '->', int(elapsed // 60), elapsed % 60, sep=':')


# Community colors for the positive network: six gray levels from the 'gray'
# colormap (both extremes dropped) keyed 0..5, with communities 1 and 2
# overridden by fixed RGBA values.
cmap = plt.get_cmap('gray')
gray_levels = cmap(np.linspace(0, 1, 8)).tolist()[1:-1]
pathways_subcluster_colors = dict(enumerate(gray_levels))
pathways_subcluster_colors.update({
    1: [1.0, 0.6709129411764706, 0.42726470588235294, 1.0],
    2: [0.8588235294117647, 0.8588235294117647, 0.8588235294117647, 1.0],
})

# Community labels without an entry in the map get NaN.
aux['class_color'] = aux.Louvain.astype(int).map(pathways_subcluster_colors)

node_to_community = dict(aux.Louvain)
node_color = dict(aux.class_color)

# Genes per positive community (1, 2, 3), grouped by the pathway named in the
# comment above each list -- presumably taken from the enrichment CSVs.
# NOTE(review): nodes2keep is not read by the code that follows in this
# section (only nodes2highlight drives the network); kept for reference.
nodes2keep = []
# 1
# REACTOME_INNATE_IMMUNE_SYSTEM
nodes2keep.extend("""MANBA;VCL;C8G;EEA1;A1BG;PELI3;TP53;ERP44;MVP;NKIRAS1;IMPDH2;CLEC4C;BIRC2;ACLY;C6;ITLN1;PPBP;APP;SLC27A2;HPSE;PROS1;S100P;CRP;CTSK;PAK3;FCN3;DHX58;ARPC2;MAP2K3;POLR3D;SERPINA1;SERPING1;GALNS;LAT2;TRIM21;C5;GGH;PPP2R1A;TREM2;CFHR5;WASF2;IRF5;TLR4;POLR3E;CASP10;CTSA;ATP6V1F;TIMP2;PGLYRP2;DUSP6;UBE2D1;CD81;RPS6KA1;AGER;MAP3K14;PSMD4;CASP4;VCP;ARHGAP9;STK10;SEMG1;CFHR4;CAP1""".split(';'))
# GOBP_MYELOID_LEUKOCYTE_MIGRATION
nodes2keep.extend("""CXCL12;CXCL2;CXCL3;THBS1;PPBP;MMP2;PDGFB;PF4;DNM1L;CCL5;PDGFD;C5;TREM2;SERPINE1;MMP14;CD81;AGER;PLCB1;CXCL6""".split(';'))
# REACTOME_TOLL_LIKE_RECEPTOR_CASCADES
nodes2keep.extend("""BIRC2;UBE2D1;RPS6KA1;AGER;APP;EEA1;PPP2R1A;CTSK;PELI3;TP53;NKIRAS1;IRF5;TLR4;MAP2K3;DUSP6""".split(';'))
# GOBP_CELLULAR_RESPONSE_TO_INTERLEUKIN_1
nodes2keep.extend("""HIF1A;MMP2;YY1;PLCB1;DAB2IP;INHBB;UPF1;CCL5""".split(';'))

# 2
# REACTOME_INNATE_IMMUNE_SYSTEM
nodes2keep.extend("""NCF4;LGALS3;HCK;SNAP25;DEFB121;ALAD;POLR1D;LY86;CFI;CFHR2;GNLY;VRK3;HSP90AB1;CFH;FGG;PGM2;OPTN;ANXA2;POLR2E;PSMD6;GAA;ISG15;PSMD8;AP2A2;FGR;PSMC3;RNASET2;PELI2;IFNA8;TIRAP;CCL22;RAB31;PPP2R5D;GDI2;GSTP1;RAB24;GYG1;HSP90AA1;MBL2;C4BPA;C3;HP;PSMC1;LCP2;PGAM1;C9;GMFG;CFB;CREBBP""".split(';'))

# 3
# REACTOME_INTERFERON_SIGNALING
nodes2keep.extend("""EIF2S2;MAVS;SPHK1;MX1;TRIM25;IFIT3;TRIM3;KPNA1;PTPN6;PTPN11;IRF6;NCK1;EIF4G1;FLNA;RAF1;IKBKG;IKBKB""".split(';'))
# GOBP_FC_RECEPTOR_SIGNALING_PATHWAY
nodes2keep.extend("""MAPK8;SRC;VAV1;FER;FYN;SYK;IKBKB""".split(';'))

# De-duplicate (genes shared between pathways appear more than once above).
nodes2keep = list(set(nodes2keep))

# Proteins labelled in the positive network plot, with the community and
# pathway each group was picked from.
nodes2highlight = ['CXCL12','CXCL2','CXCL3','CXCL6', #positive 1 GOBP_MYELOID_LEUKOCYTE_MIGRATION
                   'BIRC2','IRF5','TLR4','MAP2K3', #positive 1 REACTOME_TOLL_LIKE_RECEPTOR_CASCADES
                   'HIF1A','MMP2','CCL5', #positive 1 GOBP_CELLULAR_RESPONSE_TO_INTERLEUKIN_1
                   
                   'ISG15','PSMD8','IFNA8','TIRAP','CCL22', # positive 2 REACTOME_INNATE_IMMUNE_SYSTEM
                   
                   'IRF6','IKBKG','IKBKB', # positive 3 REACTOME_INTERFERON_SIGNALING
                   'MAPK8','SRC','VAV1','FER','FYN','SYK','IKBKB', # positive 3 GOBP_FC_RECEPTOR_SIGNALING_PATHWAY
                  ]

# De-duplicate (IKBKB is listed twice), then add the highlighted metabolites.
nodes2highlight = list(set(nodes2highlight))
nodes2highlight.extend(['Choline', 'L-Kynurenine', 'Adenine', 'Eicosenoic acid'])
nodes2keep.extend(['Choline', 'L-Kynurenine', 'Adenine', 'Eicosenoic acid'])

# Per-node drawing attributes for the positive network. Highlighted nodes are
# larger, opaque, drawn on top and outlined more thickly. Shape encodes the
# data layer (^ metabolomics, o proteomics, s clinical), and metabolomics and
# clinical nodes are enlarged to 2.5 regardless of highlighting.
# Membership sets are built once instead of re-filtering `aux` for every node.
highlighted = set(nodes2highlight)
metabolomics_nodes = set(aux[aux.type == 'metabolomics'].index)
proteomics_nodes = set(aux[aux.type == 'proteomics'].index)
clinical_nodes = set(aux[aux.type == 'clinical'].index)

node_size = dict()
node_alpha = dict()
node_shape = dict()
node_zorder = dict()
node_edge_width = dict()
for n in aux.index:
    if n in highlighted:
        node_size[n] = 2.
        node_alpha[n] = 1
        node_zorder[n] = 3
        node_edge_width[n] = .5
    else:
        node_size[n] = 1.5
        node_alpha[n] = .8
        node_zorder[n] = 2
        node_edge_width[n] = 0.1

    if n in metabolomics_nodes:
        node_shape[n] = '^'
        node_size[n] = 2.5
    elif n in proteomics_nodes:
        node_shape[n] = 'o'
    elif n in clinical_nodes:
        node_shape[n] = 's'
        node_size[n] = 2.5


# kNN graph over every node of the positive network (proteins plus the
# metabolomics/clinical variables labelled above), taken from the kNN
# classifier fitted on the FDataGrid trajectories.
grid_points = [-3, 7, 13, 20, 26, 33, 51]
fd_aux = FDataGrid(aux.loc[:, 'Pre_Tx':'D51'].values, grid_points)

neigh_aux = KNeighborsClassifier()
neigh_aux.fit(fd_aux, aux['Louvain'].values)

G_pos = nx.from_scipy_sparse_array(neigh_aux.kneighbors_graph())
# Node i is row i of aux. enumerate() works whatever the index is named,
# whereas reset_index()['index'] only exists for an unnamed index.
G_pos = nx.relabel_nodes(G_pos, dict(enumerate(aux.index)))

# Restrict to communities 1-3.
T = G_pos.subgraph(set(aux[aux['Louvain'].isin([1, 2, 3])].index))

# Keep the highlighted nodes plus every node on a shortest path between two of them.
nk = set(nodes2highlight)

for s, t in itertools.combinations(list(nk), 2):
    # Highlighted nodes may be absent from T or lie in different components;
    # nx.shortest_path would then raise, so check reachability first.
    if has_path(T, s, t):
        nk.update(nx.shortest_path(T, source=s, target=t))

T = nx.subgraph(T, nk)

# Draw the positive network with a community layout and bundled edges. Every
# highlighted node gets a bold label colored by its data layer. The drawing
# time is printed as minutes:seconds.
# NOTE(review): `plt` (matplotlib.pyplot) and `adjust_text` (adjustText) are
# not imported in the visible import block.
time1_aux = time.time()
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111)

g = Graph(T,
          node_color={k: v for k, v in node_color.items() if k in T.nodes},
          node_size={k: v for k, v in node_size.items() if k in T.nodes},
          node_shape={k: v for k, v in node_shape.items() if k in T.nodes},
          node_alpha={k: v for k, v in node_alpha.items() if k in T.nodes},
          node_edge_width={k: v for k, v in node_edge_width.items() if k in T.nodes},
          edge_alpha=0.25,
          edge_width=.5,
          node_zorder={k: v for k, v in node_zorder.items() if k in T.nodes},
          edge_zorder=1,
          node_layout='community',
          node_layout_kwargs=dict(node_to_community={k: v for k, v in node_to_community.items() if k in T.nodes}),
          edge_layout='bundled',
          # Draw into the axis created above. With ax=None netgraph instantiates
          # its own axis, so the labels added to `ax` below would land on a
          # different, empty plot.
          ax=ax,
)

highlighted = set(nodes2highlight)
metabolomics_nodes = set(aux[aux.type == 'metabolomics'].index)
proteomics_nodes = set(aux[aux.type == 'proteomics'].index)
clinical_nodes = set(aux[aux.type == 'clinical'].index)

texts = []
for k, v in g.node_positions.items():
    if k not in highlighted:
        continue
    if k in metabolomics_nodes:
        color = [0.134692, 0.658636, 0.517649, 1.]
    elif k in proteomics_nodes:
        color = [0.253935, 0.265254, 0.529983, 1.]
    elif k in clinical_nodes:
        color = '#c78100'
    else:
        # Unknown layer: previously the previous label's color was reused
        # (or NameError on the first label).
        color = 'black'
    texts.append(ax.text(v[0], v[1], k, ha='center', va='center', size=10,
                         fontweight='bold',
                         color=color,
                         bbox=dict(facecolor='white', alpha=0.5, pad=0, boxstyle='round')))

# Presumably adjustText's label-overlap resolver.
adjust_text(texts)
plt.show()
time2_aux = time.time()
print(int((time2_aux - time1_aux) // 60), (time2_aux - time1_aux) % 60, sep=':')



# Negative FDA cluster: proteins of FDA cluster 2 with log2(D51/Pre_Tx) < 0.
# .copy() because a 'Louvain' column is added below (avoids assigning into a
# filtered slice / SettingWithCopyWarning).
df_clin_mb_ss_neg = df_clin_mb_ss[
    (df_clin_mb_ss.FDA_cluster_clin_mb_ss == 2)
    & (df_clin_mb_ss.type == 'proteomics')
    & (df_clin_mb_ss['log2(FC)'] < 0.)
].copy()

# Louvain communities on the k-nearest-neighbour graph of the protein
# trajectories (sklearn default n_neighbors=5).
nearestNeighbors_object = NearestNeighbors()
louvain_object = Louvain(random_state=42, resolution=.25)

cl_X = df_clin_mb_ss_neg.loc[:, 'Pre_Tx':'D51'].values
nearestNeighbors_object.fit(cl_X)
labels = louvain_object.fit_predict(nearestNeighbors_object.kneighbors_graph())

df_clin_mb_ss_neg['Louvain'] = labels


## Propagate the protein community labels to metabolomics and clinical variables
# Sampling days (POD) of the seven time columns Pre_Tx..D51.
grid_points = [-3, 7, 13, 20, 26, 33, 51]
fd_train = FDataGrid(df_clin_mb_ss_neg.loc[:, 'Pre_Tx':'D51'].values, grid_points)

neigh = KNeighborsClassifier()
neigh.fit(fd_train, labels)

# NOTE(review): unlike the proteins, these variables are not filtered on log2(FC).
aux = df_clin_mb_ss[
    (df_clin_mb_ss.FDA_cluster_clin_mb_ss == 2) & (df_clin_mb_ss.type != 'proteomics')
].copy()
fd_predict = FDataGrid(aux.loc[:, 'Pre_Tx':'D51'].values, grid_points)
labels_mb_clin = neigh.predict(fd_predict)

print(np.unique(labels_mb_clin, return_counts=True))

aux['Louvain'] = labels_mb_clin

# One frame with every node of the negative network: time course + community + layer.
network_columns = ['Pre_Tx', 'D7', 'D13', 'D20', 'D26', 'D33', 'D51', 'Louvain', 'type']
aux = pd.concat([df_clin_mb_ss_neg.loc[:, network_columns],
                 aux.loc[:, network_columns]
                ], axis=0)

# Community colors for the negative network: six gray levels from the reversed
# 'gray' colormap (both extremes dropped) keyed 0..5, with communities 0, 2
# and 4 overridden by fixed gray levels.
cmap = plt.get_cmap('gray_r')
gray_levels = cmap(np.linspace(0, 1, 8)).tolist()[1:-1]
pathways_subcluster_colors = dict(enumerate(gray_levels))
pathways_subcluster_colors.update({
    0: [0.28627450980392155, 0.28627450980392155, 0.28627450980392155, 1.0],
    2: [0.42745098039215684, 0.42745098039215684, 0.42745098039215684, 1.0],
    4: [0.5725490196078431, 0.5725490196078431, 0.5725490196078431, 1.0],
})

# Community labels without an entry in the map get NaN.
aux['class_color'] = aux.Louvain.astype(int).map(pathways_subcluster_colors)

node_to_community = dict(aux.Louvain)
node_color = dict(aux.class_color)

## Enrichment of each negative Louvain community with gseapy.enrichr against
## the .gmt gene-set files below; rows with adjusted P <= 0.1 are written to
## one CSV per community and the elapsed time is printed.
enr = {}
gene_sets = ['c2.cp.reactome.v2024.1.Hs.symbols.gmt',
             'c7.immunesigdb.v2024.1.Hs.symbols.gmt',
             'c5.go.bp.v2024.1.Hs.symbols.gmt',
             'c2.cp.kegg_medicus.v2024.1.Hs.symbols.gmt']

for community in df_clin_mb_ss_neg.Louvain.unique():
    if community in enr:
        continue
    started = time.time()
    members = df_clin_mb_ss_neg[df_clin_mb_ss_neg.Louvain == community].index.tolist()
    enr[community] = gp.enrichr(gene_list=members,
                                gene_sets=gene_sets,
                                organism='homo sapiens',
                                outdir=None,
                                no_plot=True)
    results = enr[community].results
    significant = results[results['Adjusted P-value'] <= .1].sort_values('Adjusted P-value')
    # NOTE(review): absolute output directory here, while the positive-network
    # CSVs use a relative path -- confirm the intended location.
    significant.to_csv('/mnt/d/xeno/enrichment_' + str(community) + '_negative_FDA_Louvain_cluster_c2_c5_c7_default_background.csv')
    elapsed = time.time() - started
    print(str(community) + '->', int(elapsed // 60), elapsed % 60, sep=':')

# Genes per negative community (0, 2, 4) -- presumably taken from the
# enrichment CSVs written above.
# NOTE(review): nodes2keep is not read by the code that follows in this
# section (only nodes2highlight drives the network); kept for reference.
nodes2keep = []
# 0
nodes2keep.extend("""TNF;EFNB3;EBI3;ITGAL;SMARCE1;ZAP70;IL2RA;IL18R1;PTPN2;TGFB1;CLEC7A;CD55;SOCS6;RIPK3;SLAMF6;ZP4;SMARCD1;HMGB1;SEMA4A;TNFRSF1B;JAML;VCAM1;IGF1;CD1D;SIRPB1;IL4;SH2D2A;BTN3A1;LILRB1;SOCS3;SCRIB;IL12RB1;SOCS1;TNFSF8;ABL2;IFNB1;BCL2;ERBB2;SOD1;LILRB2;BTN2A2;CD4;PLA2G5;CR1;KLRK1;PRKCZ;FCGR2B;CTSG;IL6ST;CD46""".split(';'))
nodes2keep.extend("""EBI3;ZAP70;IL2RA;TGFB1;CD55;RIPK3;ZP4;HMGB1;TNFRSF1B;VCAM1;IL4;CD1D;IGF1;SH2D2A;BTN3A1;LILRB1;SCRIB;IL12RB1;TNFSF8;ERBB2;LILRB2;BTN2A2;CR1;PLA2G5;IL6ST;CD46""".split(';'))
# 2
nodes2keep.extend("""CCR5;CXCL11;PADI2;OSM;SH2B2;CXCL10;EPOR;IL17RC;IL17F;RFFL;TNFRSF13C;IL1A;STAT5B;IL10RA;TNFSF11;IL10;ARG1;IL5RA;TIFA;IL17A;CCR5;SPI1;CX3CL1;LILRB4;IL1RN;TNFRSF11B;XIAP;SH2B3;HSPA1A;CXCL9;STAT2""".split(';'))
# 4
nodes2keep.extend("""EDN1;PLA2G1B;TLR2;KMT5C;TNFAIP3;XCL1;CCL7;FCN1;IL4I1;KIR2DS2;CTSS;RTN4;NCR3;IL1RL1;CD80;RBCK1;HLA-DQA2;GATA1;P2RY12;MMP8;SELE;REG3G;CRTAM;CD19;ADAM17;CD8A;CALR;KIR2DL4;ITGA2B;RBP4;ITCH;LEP;NOS2;ITGAM;CXCL13;MLH1;LAG3;IRF4""".split(';'))
nodes2keep.extend("""PLA2G1B;KMT5C;DAO;XCL1;EVPL;FCN1;GZMB;IL4I1;CEBPG;CSF2RB;NCR3;CD80;GATA1;INS;CRTAM;CD19;ADAM17;CD8A;KIR2DL4;F2;RBP4;NOS2;LEP;ITGAM;MLH1;LAG3;IRF4""".split(';'))

# De-duplicate (lists overlap, and CCR5 appears twice in one list).
nodes2keep = list(set(nodes2keep))

# Proteins labelled in the negative network plot, with the community and
# pathway each group was picked from.
nodes2highlight = ['ZAP70','IL2RA','IL18R1','IL12RB1','CD4','KLRK1','TNF', #negative 0 T-CELL proliferation
              'CXCL9','STAT2','IL17A','IL17RC','IL17F','CXCL10','CXCL11','CCR5', #negative 2 GOBP_CYTOKINE_MEDIATED_SIGNALING_PATHWAY
              'XCL1','GZMB','CD80','CD19','CD8A','HLA-DQA2','NOS2','ITGAM','CXCL13', #negative 4 GOBP_POSITIVE_REGULATION_OF_IMMUNE_SYSTEM_PROCESS,GOBP_IMMUNE_EFFECTOR_PROCESS
             ]

nodes2highlight = list(set(nodes2highlight))

# Highlighted metabolites.
nodes2highlight.extend(['Glycerophosphocholine','Homocysteine','Indoxyl','L-methionine','Caproic acid','Guanidinoacetic acid','Oxoglutaric acid','Uric Acid'])
nodes2keep.extend(['Glycerophosphocholine','Homocysteine','Indoxyl','L-methionine','Caproic acid','Guanidinoacetic acid','Oxoglutaric acid','Uric Acid'])

# Every clinical variable assigned to communities 0, 2 or 4 is highlighted too.
nodes2highlight.extend( aux[(aux['Louvain'].isin([0,2,4])) & (aux.type=='clinical')].index.tolist() )
nodes2keep.extend( aux[(aux['Louvain'].isin([0,2,4])) & (aux.type=='clinical')].index.tolist() )


# Per-node drawing attributes for the negative network. Highlighted nodes are
# larger, opaque, drawn on top and outlined more thickly. Shape encodes the
# data layer (^ metabolomics, o proteomics, s clinical), and metabolomics and
# clinical nodes are enlarged to 2.5 regardless of highlighting.
# Membership sets are built once instead of re-filtering `aux` for every node.
highlighted = set(nodes2highlight)
metabolomics_nodes = set(aux[aux.type == 'metabolomics'].index)
proteomics_nodes = set(aux[aux.type == 'proteomics'].index)
clinical_nodes = set(aux[aux.type == 'clinical'].index)

node_size = dict()
node_alpha = dict()
node_shape = dict()
node_zorder = dict()
node_edge_width = dict()
for n in aux.index:
    if n in highlighted:
        node_size[n] = 2.
        node_alpha[n] = 1
        node_zorder[n] = 3
        node_edge_width[n] = .5
    else:
        node_size[n] = 1.5
        node_alpha[n] = .8
        node_zorder[n] = 2
        node_edge_width[n] = 0.1

    if n in metabolomics_nodes:
        node_shape[n] = '^'
        node_size[n] = 2.5
    elif n in proteomics_nodes:
        node_shape[n] = 'o'
    elif n in clinical_nodes:
        node_shape[n] = 's'
        node_size[n] = 2.5


# kNN graph over every node of the negative network (proteins plus the
# metabolomics/clinical variables labelled above), taken from the kNN
# classifier fitted on the FDataGrid trajectories.
grid_points = [-3, 7, 13, 20, 26, 33, 51]
fd_aux = FDataGrid(aux.loc[:, 'Pre_Tx':'D51'].values, grid_points)

neigh_aux = KNeighborsClassifier()
neigh_aux.fit(fd_aux, aux['Louvain'].values)

G_neg = nx.from_scipy_sparse_array(neigh_aux.kneighbors_graph())
# Node i is row i of aux. enumerate() works whatever the index is named,
# whereas reset_index()['index'] only exists for an unnamed index.
G_neg = nx.relabel_nodes(G_neg, dict(enumerate(aux.index)))

# Restrict to communities 0, 2 and 4.
T = G_neg.subgraph(set(aux[aux['Louvain'].isin([0, 2, 4])].index))

# Keep the highlighted nodes plus every node on a shortest path between two
# of them. Pairs are taken from a snapshot, so growing nk does not affect them.
nk = set(nodes2highlight)

for s, t in itertools.combinations(list(nk), 2):
    # Skip pairs that are absent from T or lie in different components.
    if has_path(T, s, t):
        nk.update(nx.shortest_path(T, source=s, target=t))

T = nx.subgraph(T, nk)


# Draw the negative network with a community layout and bundled edges. Every
# highlighted node gets a bold label colored by its data layer. The drawing
# time is printed as minutes:seconds.
# NOTE(review): `plt` (matplotlib.pyplot) and `adjust_text` (adjustText) are
# not imported in the visible import block.
time1_aux = time.time()
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111)

g = Graph(T,
          node_color={k: v for k, v in node_color.items() if k in T.nodes},
          node_size={k: v for k, v in node_size.items() if k in T.nodes},
          node_shape={k: v for k, v in node_shape.items() if k in T.nodes},
          node_alpha={k: v for k, v in node_alpha.items() if k in T.nodes},
          node_edge_width={k: v for k, v in node_edge_width.items() if k in T.nodes},
          edge_alpha=0.25,
          edge_width=.5,
          node_zorder={k: v for k, v in node_zorder.items() if k in T.nodes},
          edge_zorder=1,
          node_layout='community',
          node_layout_kwargs=dict(node_to_community={k: v for k, v in node_to_community.items() if k in T.nodes}),
          edge_layout='bundled',
          ax=ax,
)

highlighted = set(nodes2highlight)
metabolomics_nodes = set(aux[aux.type == 'metabolomics'].index)
proteomics_nodes = set(aux[aux.type == 'proteomics'].index)
clinical_nodes = set(aux[aux.type == 'clinical'].index)

texts = []
for k, v in g.node_positions.items():
    if k not in highlighted:
        continue
    if k in metabolomics_nodes:
        color = [0.134692, 0.658636, 0.517649, 1.]
    elif k in proteomics_nodes:
        color = [0.253935, 0.265254, 0.529983, 1.]
    elif k in clinical_nodes:
        color = '#c78100'
    else:
        # Unknown layer: previously the previous label's color was reused
        # (or NameError on the first label).
        color = 'black'
    texts.append(ax.text(v[0], v[1], k, ha='center', va='center', size=10,
                         fontweight='bold',
                         color=color,
                         bbox=dict(facecolor='white', alpha=0.5, pad=0, boxstyle='round')))

# Presumably adjustText's label-overlap resolver.
adjust_text(texts)
plt.show()
time2_aux = time.time()
print(int((time2_aux - time1_aux) // 60), (time2_aux - time1_aux) % 60, sep=':')